import torch.nn as nn
import torch
device = torch.device("cpu")
MAX_LEN = 7
import torch.nn.functional as F


class Encoder(nn.Module):
    """GRU encoder that embeds token-id sequences and runs them through a GRU.

    Args:
        input_size: Number of rows in the embedding table (source vocabulary size).
        hidden_size: Dimension of the embeddings and of the GRU hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        # batch_first=True: GRU inputs/outputs are laid out as (batch, seq, feature).
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input, hidden=None):
        """Encode a batch of token-id sequences.

        Args:
            input: LongTensor of token ids, expected shape (batch, seq_len).
            hidden: Initial GRU state of shape (1, batch, hidden_size), e.g. from
                ``initHidden(batch_size)``. ``None`` lets the GRU start from zeros.

        Returns:
            Tuple ``(output, last_hidden)``: ``output`` holds the GRU outputs for
            every position, shape (batch, seq_len, hidden_size); ``last_hidden``
            is the final state of the last GRU layer, shape (batch, hidden_size).
        """
        # Embed the token ids, then feed them to the GRU to extract sentence features.
        embedded = self.embedding(input)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden[-1]

    def initHidden(self, batch_size=1):
        """Return an all-zero initial GRU state of shape (1, batch_size, hidden_size).

        The tensor is created on the same device and with the same dtype as the
        module's parameters, so it remains valid after ``.to(...)`` / ``.half()``.
        """
        weight = self.embedding.weight
        return torch.zeros(
            1, batch_size, self.hidden_size, device=weight.device, dtype=weight.dtype
        )


# Decoder class; this decoder applies an attention mechanism over the encoder outputs.
class AttentionDencoder(nn.Module):
    """Single-step GRU decoder with attention over the encoder outputs.

    Each call consumes one target token per batch element. Attention logits are
    produced by a linear layer over [embedded token; previous hidden state], which
    has ``max_length`` outputs, so at most ``max_length`` encoder positions can be
    attended to. (The class-name spelling is kept for compatibility with callers.)

    Args:
        hidden_size: Dimension of embeddings, attention context and GRU state.
        output_size: Target vocabulary size (embedding rows and output logits).
        dropout_p: Dropout probability applied to the token embedding.
        max_length: Maximum number of encoder positions that can be attended to.
    """

    def __init__(self, hidden_size, output_size, dropout_p=0.5, max_length=MAX_LEN):
        super(AttentionDencoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        # A single-layer GRU ignores its `dropout` argument (it only applies
        # between stacked layers) and warns about it, so it is not passed here;
        # regularisation comes from self.dropout on the embedding instead.
        self.gru = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: Token ids for this step, shape (batch,) or (batch, 1).
            hidden: Previous decoder state, shape (batch, hidden_size) (e.g. the
                second value returned by this method). A 3-D state of shape
                (num_layers, batch, hidden_size), such as ``initHidden()``, is
                also accepted; its last layer is used.
            encoder_outputs: Encoder outputs, shape (batch, src_len, hidden_size)
                with ``src_len <= max_length``.

        Returns:
            Tuple ``(logits, hidden, attn_weights)``: unnormalised logits of shape
            (batch, 1, output_size), the new state (batch, hidden_size), and the
            attention weights (batch, 1, src_len), which sum to 1 over src_len.

        Raises:
            ValueError: if ``encoder_outputs`` has more than ``max_length`` positions.
        """
        batch_size = input.size(0)
        src_len = encoder_outputs.size(1)
        if src_len > self.max_length:
            raise ValueError(
                f"encoder_outputs has {src_len} positions but max_length is "
                f"{self.max_length}"
            )
        if hidden.dim() == 3:
            # Accept a (num_layers, batch, hidden) state such as initHidden().
            hidden = hidden[-1]

        embedded = self.embedding(input).view(batch_size, -1, self.hidden_size)
        embedded = self.dropout(embedded)

        # Attention logits over all max_length slots; keep only the src_len real
        # positions and normalise over that (last) axis. `dim` must be explicit:
        # for a 3-D tensor the implicit default is dim=0, i.e. the batch axis.
        attn_logits = self.attn(torch.cat((embedded, hidden.unsqueeze(1)), -1))
        attn_weights = F.softmax(attn_logits[..., :src_len], dim=-1)

        # (batch, 1, src_len) x (batch, src_len, hidden) -> (batch, 1, hidden)
        attn_applied = torch.bmm(attn_weights, encoder_outputs)

        output = torch.cat((embedded, attn_applied), -1)
        output = F.relu(self.attn_combine(output))
        output, hidden = self.gru(output, hidden.unsqueeze(0))
        output = self.out(output)

        return output, hidden[-1], attn_weights

    def initHidden(self, batch_size=1):
        """Return an all-zero state of shape (1, batch_size, hidden_size).

        Created on the same device and with the same dtype as the module's
        parameters, so it remains valid after ``.to(...)`` / ``.half()``.
        """
        weight = self.embedding.weight
        return torch.zeros(
            1, batch_size, self.hidden_size, device=weight.device, dtype=weight.dtype
        )